package cn.lingyangwl.framework.security.ratelimit;

import cn.lingyangwl.framework.tool.core.DateUtils;
import cn.lingyangwl.framework.tool.core.exception.BizException;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
 * Manages the blacklist used by the rate limiter.
 * <p>
 * {@link #addBlacklist(String, LimitTypeEnum)} writes a marker entry to Redis under
 * {@code <key>:blacklist}. The entry holds the time (epoch millis) at which it was added and
 * expires after the configured blacklist limit time.
 * {@link #checkBlacklist(String, RateLimit)} rejects a request while such an entry exists.
 * The blacklist only applies to {@link LimitTypeEnum#IP} limits, and only when it is enabled and
 * configured with a positive limit time (in whole seconds).
 *
 * @author shenguangyang
 */
@Component
public class RateLimitManager {
    /** Suffix appended to the rate-limit key to form the Redis key of the blacklist marker. */
    private static final String BLACKLIST_KEY = "blacklist";

    @Resource
    private RedisTemplate<String, Object> redisTemplate;

    @Resource
    private RateLimitProperties rateLimitProperties;

    /**
     * Tells whether blacklist handling should be <b>skipped</b> for the given limit type.
     * <p>
     * Despite its name, this method returns {@code true} when the blacklist does NOT apply. The
     * name is kept for backward compatibility with existing callers.
     * <p>
     * The blacklist applies only when all of these hold:
     * <ul>
     *   <li>it is enabled;</li>
     *   <li>its limit time is at least one second;</li>
     *   <li>the limit type is {@link LimitTypeEnum#IP}.</li>
     * </ul>
     *
     * @param limitType the limit type of the rate limit being processed
     * @return {@code true} if blacklist handling should be skipped, {@code false} if it applies
     */
    public boolean blacklistEnabled(LimitTypeEnum limitType) {
        RateLimitProperties.Blacklist blacklist = rateLimitProperties.getBlacklist();
        if (blacklist == null || !blacklist.isEnabled()) {
            return true;
        }
        Duration limitTime = blacklist.getLimitTime();
        // The expiry is written in whole seconds, so a missing or sub-second limit time cannot be
        // used as the marker's TTL; skip the blacklist instead of writing an unusable expiry.
        return limitTime == null || limitTime.getSeconds() <= 0 || LimitTypeEnum.IP != limitType;
    }

    /**
     * Adds the given key (expected to identify an IP) to the blacklist.
     * <p>
     * Does nothing when {@link #blacklistEnabled(LimitTypeEnum)} reports that the blacklist does
     * not apply. Otherwise it stores the current time (epoch millis) under {@code <key>:blacklist},
     * with an expiry equal to the configured limit time in seconds.
     *
     * @param key       cache key of the rate limit
     * @param limitType limit type of the rate limit
     */
    public void addBlacklist(String key, LimitTypeEnum limitType) {
        if (blacklistEnabled(limitType)) {
            return;
        }
        Duration limitTime = rateLimitProperties.getBlacklist().getLimitTime();
        redisTemplate.opsForValue().set(blacklistKey(key), System.currentTimeMillis(),
                limitTime.getSeconds(), TimeUnit.SECONDS);
    }

    /**
     * Checks whether the given key is currently blacklisted and rejects the request if so.
     * <p>
     * Does nothing when {@link #blacklistEnabled(LimitTypeEnum)} reports that the blacklist does
     * not apply to {@code rateLimit.limitType()}.
     *
     * @param key       cache key of the rate limit; it should contain the IP
     * @param rateLimit rate-limit settings supplying the limit type and the rejection message
     * @throws BizException if a blacklist entry exists for the key. The message is
     *                      {@code rateLimit.msg()}, with {@code RateLimitCons.REMAINING_TIME_ARG}
     *                      replaced by the time remaining until the entry expires.
     */
    public void checkBlacklist(String key, RateLimit rateLimit) {
        if (blacklistEnabled(rateLimit.limitType())) {
            return;
        }
        Object value = redisTemplate.opsForValue().get(blacklistKey(key));
        if (Objects.isNull(value)) {
            return;
        }
        Duration limitTime = rateLimitProperties.getBlacklist().getLimitTime();
        // The entry was written with a TTL of limitTime.getSeconds(), so it expires at
        // addedAt + seconds * 1000 ms.
        long expiresAt = toEpochMillis(value) + limitTime.getSeconds() * 1000L;
        String remainingTime = DateUtils.distance(System.currentTimeMillis(), expiresAt, true);
        throw new BizException(rateLimit.msg().replace(RateLimitCons.REMAINING_TIME_ARG, remainingTime));
    }

    /** Builds the Redis key of the blacklist marker for the given rate-limit key. */
    private static String blacklistKey(String key) {
        return key + ":" + BLACKLIST_KEY;
    }

    /**
     * Converts a stored blacklist marker value to epoch milliseconds.
     * <p>
     * The value is written as a {@code long}. Depending on the configured Redis value serializer,
     * it may be read back as a different {@link Number} subtype or as text, so a plain
     * {@code (Long)} cast is not safe.
     */
    private static long toEpochMillis(Object value) {
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        return Long.parseLong(value.toString().trim());
    }
}
